#!/usr/bin/env python3
import argparse
import boto3
import os
import queue
import sys
import threading
import time
import copy

# boto3 service clients, created once when the module is loaded and used by
# the functions below: auto scaling groups, EC2 instance status, and ECS
# container instances respectively.
asg = boto3.client('autoscaling')
ec2 = boto3.client('ec2')
ecs = boto3.client('ecs')


def main(argv):
    """Roll the EC2 instances of the auto scaling group named after the cluster.

    The instances to cycle and the batch size are computed once up front;
    with ``--dry-run`` only a summary is printed.  Otherwise ``run`` is
    called repeatedly, each call cycling one batch.

    Args:
        argv: Command-line arguments, without the program name.
    """
    options = parse_arguments(argv)

    info('fetching autoscaling group %s...', options.cluster)
    group = fetch_auto_scaling_group(options.cluster)

    # Log the name of the launch configuration that is actually fetched
    # rather than the cluster name.
    launch_config_name = group['LaunchConfigurationName']
    info('fetching launch configuration %s...', launch_config_name)
    conf = fetch_launch_configuration(launch_config_name)

    instances = filter_instances_to_cycle(group['Instances'], conf['LaunchConfigurationName'], options.force)
    step_size = compute_step_size(options.number, len(instances))

    info('cycling instances %s at a time (%s exist in the group)', step_size, len(group['Instances']))
    info('using image %s', conf['ImageId'])

    if options.dry_run:
        print_summary(instances)
        return

    # Instances started by earlier batches are passed back to run() so that
    # they are never selected for cycling themselves (relevant with --force).
    # The loop has no exit of its own: run() calls sys.exit(0) once there is
    # nothing left to cycle.
    new_instances = []
    while True:
        instances, conf = run(options, step_size, conf, options.force, new_instances)
        new_instances += instances

def run(options, step_size=1, launch_config=None, force=False, ignore_instances=None):
    """Cycle one batch of instances of the cluster's auto scaling group.

    The batch is replaced in five steps: the desired capacity is raised by
    the batch size, the newly launched instances are awaited (in service in
    the group, then running and reachable in EC2), the old instances are
    set to DRAINING in ECS until they run no tasks, and finally the old
    instances are terminated while decrementing the desired capacity.

    Args:
        options: Parsed command-line options.  Only ``options.cluster`` is
            read; it is used both as the auto scaling group name and as the
            ECS cluster name.
        step_size: Maximum number of instances to cycle in this call.
        launch_config: Launch configuration (as returned by
            ``fetch_launch_configuration``) seen by the previous call, or
            ``None`` to skip the "configuration changed" check.
        force: Cycle instances even if they already use the current launch
            configuration.
        ignore_instances: Instance ids that must never be selected.

    Returns:
        A ``(new_instances, conf)`` tuple: the ids of the instances launched
        by this call and the launch configuration currently in use.

    Exits the process with status 0 when no instance is left to cycle, and
    with status 1 (through ``error``) when the launch configuration changed
    since ``launch_config`` was fetched.
    """
    # Documented per-request maximums of the AWS APIs called below.  Lists
    # of ids longer than these are split over several requests, otherwise
    # large batches are rejected by the service.
    asg_describe_limit = 50   # describe_auto_scaling_instances InstanceIds
    ec2_status_limit = 100    # describe_instance_status InstanceIds
    ecs_describe_limit = 100  # describe_container_instances containerInstances
    ecs_update_limit = 10     # update_container_instances_state containerInstances

    group = fetch_auto_scaling_group(options.cluster)
    conf = fetch_launch_configuration(group['LaunchConfigurationName'])

    if launch_config is not None and conf['LaunchConfigurationName'] != launch_config['LaunchConfigurationName']:
        # error() logs the message and exits with status 1.
        error('launch configuration changed while cycling EC2 instances, aborting!')

    instances = filter_instances_to_cycle(group['Instances'], conf['LaunchConfigurationName'], force, ignore_instances)
    if not instances:
        sys.exit(0)

    # Slicing already caps the batch at the number of remaining instances.
    old_instances = instances[:step_size]
    count = len(old_instances)

    # Increase the auto scaling group capacity to start new instances,
    # the max size is also modified if it was too low to raise the desired
    # capacity.
    # NOTE(review): the max size is not lowered again afterwards.
    raised_capacity = group['DesiredCapacity'] + count
    raised_max_size = max(group['MaxSize'], raised_capacity)
    info('increasing desired capacity of auto scaling group to %s', raised_capacity)
    asg.update_auto_scaling_group(
        AutoScalingGroupName = options.cluster,
        DesiredCapacity      = raised_capacity,
        MaxSize              = raised_max_size,
    )

    # Look for `count` new instances that will be coming up in the auto
    # scaling group (instances that didn't exist before).
    # NOTE(review): this loop, like the waits below, polls without a timeout.
    info('looking for the ids of new EC2 instances...')
    new_instances = []
    while len(new_instances) < count:
        time.sleep(1)
        starting = fetch_auto_scaling_instances(options.cluster)
        starting = filter_new_instances(starting, group['Instances'])
        for instance in starting:
            if instance not in new_instances:
                new_instances.append(instance)
                info('%s - launched', instance)

    # Wait for all new instances to become ready.
    info('waiting for new EC2 instances to become ready...')
    tmp_instances = set(new_instances)
    while tmp_instances:
        time.sleep(1)
        # sorted() takes a snapshot, so ids can be removed from the set
        # while the batches are being processed.
        for batch in iter_instance_groups(sorted(tmp_instances), asg_describe_limit):
            described = asg.describe_auto_scaling_instances(InstanceIds = batch)
            for instance in described['AutoScalingInstances']:
                if instance['LifecycleState'] == 'InService':
                    info('%s - ready', instance['InstanceId'])
                    tmp_instances.remove(instance['InstanceId'])

    # Wait for all new instances to be in a running state.
    info('waiting for new EC2 instances to be in a running state...')
    tmp_instances = set(new_instances)
    while tmp_instances:
        time.sleep(1)
        for batch in iter_instance_groups(sorted(tmp_instances), ec2_status_limit):
            # The filters make the API return only instances that are
            # running and passed the reachability check.
            statuses = ec2.describe_instance_status(
                InstanceIds = batch,
                Filters     = [
                    { 'Name': 'instance-state-name',          'Values': ['running'] },
                    { 'Name': 'instance-status.reachability', 'Values': ['passed']  },
                ],
            )
            for instance in statuses['InstanceStatuses']:
                info('%s - running', instance['InstanceId'])
                tmp_instances.remove(instance['InstanceId'])

    container_instance_id_map = build_container_instance_map(options.cluster)

    # Drain old instances.  All ARNs are resolved before the first request
    # so that an instance missing from the map (KeyError) aborts before
    # anything has been set to DRAINING.
    info('draining ECS container instances...')
    old_container_instances = [container_instance_id_map[instance] for instance in old_instances]
    for instance in old_instances:
        info('%s - draining', instance)
    for batch in iter_instance_groups(old_container_instances, ecs_update_limit):
        ecs.update_container_instances_state(
            cluster=options.cluster,
            containerInstances=batch,
            status='DRAINING')

    # Poll until none of the old instances runs a task anymore.
    draining_instances = list(old_instances)
    while draining_instances:
        time.sleep(5)
        still_draining = [container_instance_id_map[instance] for instance in draining_instances]
        for batch in iter_instance_groups(still_draining, ecs_describe_limit):
            ci_info = ecs.describe_container_instances(
                cluster=options.cluster,
                containerInstances=batch)
            for container_instance in ci_info['containerInstances']:
                if container_instance['runningTasksCount'] == 0:
                    draining_instances.remove(container_instance['ec2InstanceId'])
                    info('%s - drained', container_instance['ec2InstanceId'])

    # Terminates the old instances that aren't necessary anymore, bringing
    # the desired capacity back down by one for each of them.
    for instance in old_instances:
        asg.terminate_instance_in_auto_scaling_group(
            InstanceId                     = instance,
            ShouldDecrementDesiredCapacity = True,
        )
        info('%s - terminated', instance)

    return new_instances, conf

def log(msg, *args):
    """Write ``msg`` followed by a newline to standard output.

    When ``args`` are given, the line is %-formatted with them first; with
    no ``args`` the message is written verbatim, so a literal ``%`` is safe.
    """
    line = msg + '\n'
    sys.stdout.write(line % args if args else line)

def warn(msg, *args):
    """Log ``msg`` (formatted with ``args``) with a ``warn: `` prefix."""
    prefixed = 'warn: ' + msg
    log(prefixed, *args)

def error(msg, *args):
    """Log ``msg`` (formatted with ``args``) with an ``error: `` prefix, then exit.

    This function never returns: the process exits with status 1.
    """
    prefixed = 'error: ' + msg
    log(prefixed, *args)
    sys.exit(1)

def info(msg, *args):
    """Log ``msg`` (formatted with ``args``) with the ``==> roll: `` prefix."""
    prefixed = '==> roll: ' + msg
    log(prefixed, *args)

def parse_arguments(argv):
    """Parse the command-line arguments in ``argv``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``cluster``,
        ``number``, ``force`` and ``dry_run``.
    """
    # One entry per option: the flag spellings and the add_argument settings.
    option_specs = (
        (('-c', '--cluster'), {
            'default': 'default',
            'help': "The name of the cluster to roll EC2 instances in",
        }),
        (('-n', '--number'), {
            'default': '25%',
            'help': "The number of EC2 instances to roll at once, may be absolute or a percentage",
        }),
        (('--force',), {
            'action': 'store_true',
            'help': "When specified the program will cycle instances even if the launch configuration did not change",
        }),
        (('--dry-run',), {
            'action': 'store_true',
            'help': "When specified the program stops before making any changes",
        }),
    )

    parser = argparse.ArgumentParser()
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
    return parser.parse_args(argv)

def print_summary(instances):
    """Print the list of instances a real run would cycle.

    Args:
        instances: Iterable of instance ids; each is printed on its own
            bulleted line, framed by a blank line before and after.
    """
    print()
    # Fixed typo in the user-facing heading ("would by" -> "would be").
    print(' The following instances would be cycled:')
    for instance in instances:
        print('  *', instance)
    print()

def compute_step_size(number, count):
    """Turn the ``--number`` option into a concrete batch size.

    Args:
        number: Either an absolute count (``'3'``) or a percentage of
            ``count`` (``'25%'``), as a string.
        count: Number of instances that are candidates for cycling.

    Returns:
        The batch size, raised to at least 1 and then capped at ``count``
        (so the result is 0 when ``count`` is 0).

    A percentage outside of ]0, 100] is reported through ``error``.
    """
    if not number.endswith('%'):
        requested = int(number)
    else:
        percentage = float(number[:-1])
        if percentage <= 0 or percentage > 100:
            error('invalid percentage of instances to roll %s', percentage)
        requested = round(count * (percentage / 100.0))

    at_least_one = max(1, requested)
    return min(at_least_one, count)

def fetch_auto_scaling_group(name):
    """Return the description of the auto scaling group called ``name``.

    Args:
        name: Auto scaling group name.

    Returns:
        The first entry of ``AutoScalingGroups`` in the
        ``describe_auto_scaling_groups`` response.

    Exits through ``error`` when no group with that name is returned,
    instead of failing with a bare ``IndexError``.
    """
    response = asg.describe_auto_scaling_groups(
        AutoScalingGroupNames=[name],
    )
    groups = response['AutoScalingGroups']
    if not groups:
        error('auto scaling group %s not found', name)
    return groups[0]

def fetch_auto_scaling_instances(name):
    """Return the ``Instances`` list of the auto scaling group called ``name``."""
    return fetch_auto_scaling_group(name)['Instances']

def fetch_launch_configuration(name):
    """Return the description of the launch configuration called ``name``.

    Args:
        name: Launch configuration name.

    Returns:
        The first entry of ``LaunchConfigurations`` in the
        ``describe_launch_configurations`` response.

    Exits through ``error`` when no launch configuration with that name is
    returned, instead of failing with a bare ``IndexError``.
    """
    response = asg.describe_launch_configurations(
        LaunchConfigurationNames=[name],
    )
    configurations = response['LaunchConfigurations']
    if not configurations:
        error('launch configuration %s not found', name)
    return configurations[0]

def filter_instances_to_cycle(instances, launch_config, force=False, ignore_instances=None):
    """Select the ids of the instances that should be cycled.

    Args:
        instances: Instance descriptions (dicts) of an auto scaling group.
        launch_config: Name of the current launch configuration.
        force: When true, instances already using ``launch_config`` are
            selected as well.
        ignore_instances: Optional container of instance ids to leave out.

    Returns:
        A sorted list of instance ids.  Only instances in a pending,
        quarantined or in-service lifecycle state are considered.
    """
    cyclable_states = ('Pending', 'Pending:Wait', 'Pending:Proceed', 'Quarantined', 'InService')

    selected = [
        candidate['InstanceId']
        for candidate in instances
        if candidate.get('LifecycleState') in cyclable_states
        and (force or not candidate.get('LaunchConfigurationName') == launch_config)
    ]

    if ignore_instances is not None:
        selected = [instance_id for instance_id in selected if instance_id not in ignore_instances]

    return sorted(selected)

def filter_new_instances(instances, group_instances):
    """Return the ids of instances that were not part of the group before.

    Args:
        instances: Current instance descriptions (dicts) of the group.
        group_instances: Instance descriptions captured earlier; any id
            present here is considered already known.

    Returns:
        A sorted list of the ids found in ``instances`` only, restricted to
        instances in a pending, quarantined or in-service lifecycle state.
    """
    live_states = ('Pending', 'Pending:Wait', 'Pending:Proceed', 'Quarantined', 'InService')
    known_ids = {known['InstanceId'] for known in group_instances}

    return sorted(
        candidate['InstanceId']
        for candidate in instances
        if candidate.get('LifecycleState') in live_states
        and candidate['InstanceId'] not in known_ids
    )

def build_container_instance_map(cluster):
    """Map EC2 instance ids to ECS container instance ARNs for ``cluster``.

    Args:
        cluster: Name of the ECS cluster.

    Returns:
        A dict ``{ec2InstanceId: containerInstanceArn}`` covering every
        container instance listed for the cluster; empty when the cluster
        has none.
    """
    # describe_container_instances accepts at most this many entries per
    # request.
    describe_limit = 100

    # list_container_instances is paginated: walk every page so that large
    # clusters are mapped completely.
    arns = []
    paginator = ecs.get_paginator('list_container_instances')
    for page in paginator.paginate(cluster=cluster):
        arns.extend(page['containerInstanceArns'])

    # With no ARNs the loop body never runs, which avoids calling
    # describe_container_instances with an empty list (rejected by the API).
    container_instance_map = {}
    for start in range(0, len(arns), describe_limit):
        full_info = ecs.describe_container_instances(
            cluster=cluster,
            containerInstances=arns[start:start + describe_limit])
        for ci in full_info['containerInstances']:
            container_instance_map[ci['ec2InstanceId']] = ci['containerInstanceArn']
    return container_instance_map

def iter_instance_groups(instances, group_size):
    """Yield consecutive slices of ``instances`` of at most ``group_size`` items.

    Args:
        instances: Sliceable sequence (e.g. a list of instance ids).
        group_size: Maximum length of each yielded slice; must be >= 1.

    Yields:
        Slices of ``instances`` in order; the last one may be shorter.

    Raises:
        ValueError: If ``group_size`` is smaller than 1.  Such a size would
            otherwise produce empty groups forever.  As with any generator,
            the error surfaces when iteration starts.
    """
    if group_size < 1:
        raise ValueError('group_size must be at least 1, got %r' % (group_size,))

    remaining = instances
    while remaining:
        # Both slices are taken before yielding so that the remainder is a
        # copy, unaffected by changes the consumer makes to the input.
        head, remaining = remaining[:group_size], remaining[group_size:]
        yield head

# Script entry point: hand the command-line arguments (without the program
# name) to main() when the file is executed directly.
if __name__ == '__main__':
    main(sys.argv[1:])
